
import torch
import numpy as np
import torch.nn.functional as F


def _extract_batch(model, imgs):
    """Run one image batch through the model; return (local3, final) numpy arrays.

    Both outputs are L2-normalised along dim 1 and copied to host memory.
    All intermediate tensors are local to this call, so they are released as
    soon as it returns (no manual ``del`` needed in the caller's loop).
    """
    with torch.no_grad():
        # Only the first backbone output feeds the returned features; the
        # other three outputs (and localhead2) contribute nothing, so they are
        # not evaluated further.
        out3, _, _, _ = model.backbone(imgs)
        # Fixed 8x8 pooling window. NOTE(review): this collapses the spatial
        # dims to 1x1 (so the squeezes yield a (batch, channels) tensor) only
        # if localhead3 emits 8x8 maps — confirm against the model definition.
        local3 = F.avg_pool2d(model.localhead3(out3), (8, 8), 8, 0)
        local3 = F.normalize(local3.squeeze(2).squeeze(2), p=2, dim=1)
        final = F.normalize(model.finalhead(local3), p=2, dim=1)
        return local3.cpu().numpy(), final.cpu().numpy()


def feat_extractor(model, data_loader, logger=None, *, device='cuda'):
    """Extract L2-normalised features for every batch yielded by ``data_loader``.

    Puts ``model`` into eval mode (and leaves it there). For each batch, the
    first element is passed through ``model.backbone`` -> ``model.localhead3``
    -> 8x8 average pool -> L2 norm ("feats3"). That result is then passed
    through ``model.finalhead`` -> L2 norm ("feats").

    Args:
        model: Module exposing ``backbone`` (returning a 4-tuple whose first
            item feeds ``localhead3``), ``localhead3`` and ``finalhead``.
        data_loader: Iterable of batches where ``batch[0]`` is the image
            tensor. ``len()`` is called on it when progress is logged.
        logger: Optional logger. A debug progress line is emitted every 100
            batches.
        device: Device the images are moved to. The default ``'cuda'`` matches
            the previous hard-coded ``.cuda()`` call.

    Returns:
        ``(feats3, feats, feats_mp, centers)``.
        - ``feats3``: per-batch pooled localhead3 outputs, stacked with
          ``np.vstack``.
        - ``feats``: per-batch finalhead outputs, stacked with ``np.vstack``.
        - ``feats_mp``, ``centers``: always empty lists, kept only so existing
          four-value unpacking keeps working.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    feats3 = []
    feats = []

    for i, batch in enumerate(data_loader):
        local3, final = _extract_batch(model, batch[0].to(device))
        feats3.append(local3)
        feats.append(final)

        if logger is not None and (i + 1) % 100 == 0:
            logger.debug('Extract Features: [%d/%d]', i + 1, len(data_loader))

    if not feats:
        # np.vstack([]) would fail with an opaque message; make it explicit.
        raise ValueError('feat_extractor: data_loader yielded no batches')

    # Never populated; returned for backward compatibility with callers.
    feats_mp = []
    centers = []
    return np.vstack(feats3), np.vstack(feats), feats_mp, centers
